package com.example.demo.util;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.ClassLoaderUtil;
import cn.hutool.core.util.StrUtil;
import com.example.demo.constant.EnumConstant;

import java.io.File;
import java.io.FileFilter;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * @author zhangfc
 * @date 2021/12/28 16:53
 */
public class ClassFactoryUtils {
    /**
     * 根据包名获取包下面所有的类名
     *
     * @param pack 包名
     * @return Set
     */
    public static Set<Class<?>> getClasses(String pack) {

        //第一个class类的集合
        Set<Class<?>> classes = new LinkedHashSet<>();
        //是否循环迭代
        boolean recursive = true;
        //获取包的名字 并进行替换
        String packageName = pack;
        String packageDirName = StrUtil.replace(packageName, ".", "/");
        //定义一个枚举的集合 并进行循环来处理这个目录下的things
        Enumeration<URL> dirs;
        try {
            dirs = ClassLoaderUtil.getContextClassLoader().getResources(packageDirName);
            while (dirs.hasMoreElements()) {
                //获取dirs的每一个元素
                URL url = dirs.nextElement();
                //获取dirs的每一个元素的每个协议的名称
                String protocol = url.getProtocol();
                //如果是以文件的形式保存在服务器上
                if (StrUtil.equals(EnumConstant.UtilsConStant.C_FILE, protocol)) {
                    //获取包的物理路径
                    String filePath = URLDecoder.decode(url.getFile(), "UTF-8");
                    //以文件的方式扫描整个包下的文件,并添加到集合中
                    findClassesInPackageByFile(packageName, filePath, recursive, classes);
                } else if (StrUtil.equals(EnumConstant.UtilsConStant.C_JAR, protocol)) {
                    //如果是jar包文件
                    //定义一个JarFile
                    JarFile jar;
                    try {
                        JarURLConnection jarURLConnection = (JarURLConnection) url.openConnection();
                        jar = jarURLConnection.getJarFile();
                        //从此jar包 得到一个枚举类
                        Enumeration<JarEntry> entries = jar.entries();
                        findClassesInPackageByJar(packageName, entries, packageDirName, recursive, classes);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return classes;
    }

    /**
     * 以jar的形式来获取包下的所有class
     *
     * @param packageName    原包名
     * @param entries        jar包所有的枚举
     * @param packageDirName 包的.替换成/的路径
     * @param recursive      是否迭代
     * @param classes        class集合
     */
    private static void findClassesInPackageByJar(String packageName, Enumeration<JarEntry> entries, String packageDirName, boolean recursive, Set<Class<?>> classes) {
        //迭代循环
        while (entries.hasMoreElements()) {
            //获取jar里的一个实体,可以是目录和一些jar包里的其他文件,如META-INF等文件
            JarEntry entry = entries.nextElement();
            String name = entry.getName();
            //如果以/开头的
            if (name.charAt(0) == '/') {
                //获取后面的字符串
                name = name.substring(1);
            }
            //如果前半部分和定义的包名相同
            if (name.startsWith(packageDirName)) {
                int idx = name.lastIndexOf('/');
                //如果以'/'结尾是一个包
                if (idx != -1) {
                    //获取包名 把'/'替换成'.'
                    packageName = name.substring(0, idx).replace('/', '.');
                }
                //如果可以迭代下去 并且是一个包
                if ((idx != -1) || recursive) {
                    //如果是一个 .class文件 而且不是目录
                    if (name.endsWith(".class") && !entry.isDirectory()) {
                        //去掉后面的".class"获取真正的类名
                        String className = name.substring(packageName.length() + 1, name.length() - 6);
                        //去掉类名中包含"$"的文件
                        if (StrUtil.contains(className, "$")) {
                            //继续循环
                            continue;
                        }
                        try {
                            classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + "." + className));
                        } catch (ClassNotFoundException e) {
                            e.printStackTrace();
                        }
                    }
                }
            }
        }
    }

    /**
     * 以文件的形式获取包下的所有class
     *
     * @param packageName 包名
     * @param filePath    文件路径
     * @param recursive   是否迭代循环
     * @param classes     class集合
     */
    private static void findClassesInPackageByFile(String packageName, String filePath, boolean recursive, Set<Class<?>> classes) {
        //获取此包的路径建立一个File
        File dir = new File(filePath);
        //如果不存在或者不是目录就直接返回
        if (!FileUtil.exist(dir) || !FileUtil.isDirectory(dir)) {
            //用户定义包名 packageName下没有任何文件
            return;
        }
        //如果存在就获取包下的所有文件,包括目录
        File[] dirFiles = dir.listFiles(new FileFilter() {
            @Override
            public boolean accept(File pathname) {
                return (recursive && pathname.isDirectory()) || (pathname.getName().endsWith(".class"));
            }
        });
        //循环所有文件
        for (File file : dirFiles) {
            //如果是目录 则继续扫描
            if (file.isDirectory()) {
                findClassesInPackageByFile(packageName + "." + file.getName()
                        , file.getAbsolutePath()
                        , recursive
                        , classes);
            } else {
                try {
                    //添加到集合中去
                    //如果是java类文件 去掉后面.class 只留下类名
                    String className = file.getName().substring(0, file.getName().length() - 6);
                    //去掉类名中包含"$"的文件
                    if (StrUtil.contains(className, "$")) {
                        //继续循环
                        continue;
                    }
                    classes.add(Thread.currentThread().getContextClassLoader().loadClass(packageName + "." + className));
                } catch (ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
        }

    }
}
